#include <iostream>

using namespace std;

/// A binary tree node: one integer payload plus links to two children.
///
/// No destructor is defined, so this struct never frees the nodes its
/// `left` / `right` pointers refer to.
struct TreeNode {
    int val = 0;                // payload stored in this node
    TreeNode* left = nullptr;   // left child, nullptr when absent
    TreeNode* right = nullptr;  // right child, nullptr when absent

    /// Creates a childless node whose value is 0.
    TreeNode() = default;

    /// Creates a childless node holding `x`.
    TreeNode(int x) : val{x} {}

    /// Creates a node holding `x` linked to the given children.
    TreeNode(int x, TreeNode* left, TreeNode* right)
        : val{x}, left{left}, right{right} {}
};

/// Finds the k-th node of a binary tree in in-order sequence.
///
/// When the tree is a binary search tree, in-order sequence is ascending, so
/// the result is the k-th smallest value (k is 1-based).
///
/// The traversal state lives in the object, so one instance must not be used
/// for two overlapping searches.
class Solution {
public:
    /// Value returned by kthSmallest when no k-th node exists (empty tree,
    /// k <= 0, or k larger than the number of nodes).
    static constexpr int kNotFound = -1;

private:
    /// How many more nodes must be visited in order before the answer is hit.
    int count = 0;
    /// Value of the k-th visited node; kNotFound until that node is reached.
    int ret = kNotFound;

public:
    /// Returns the value of the k-th in-order node under `root`.
    ///
    /// @param root  Root of the tree to search; may be nullptr.
    /// @param k     1-based rank of the wanted node.
    /// @return      The node's value, or kNotFound when there is no such node.
    int kthSmallest(TreeNode* root, int k) {
        // Reset the result on every call: otherwise an out-of-range k would
        // return an indeterminate value (first call) or a stale answer from
        // an earlier call on the same object.
        ret = kNotFound;
        count = k;
        dfs(root);
        return ret;
    }

    /// In-order walk that counts `count` down by one per visited node and
    /// records the node's value in `ret` when the counter reaches zero.
    ///
    /// Stops descending as soon as the answer has been recorded (or when
    /// `count` was not positive to begin with).
    ///
    /// @param root  Subtree to walk; nullptr is a no-op.
    void dfs(TreeNode* root) {
        if (root == nullptr || count <= 0) return;

        dfs(root->left);

        // The answer may already have been found inside the left subtree;
        // this node must then not be counted.
        if (count <= 0) return;

        if (--count == 0) {
            ret = root->val;
            return;  // nothing to the right can be the k-th node
        }

        dfs(root->right);
    }
};

int main() {
    TreeNode* root = new TreeNode(3);
    root->left = new TreeNode(1);
    root->right = new TreeNode(4);
    root->left->right = new TreeNode(2);

    Solution solution;
    int kth_smallest = solution.kthSmallest(root, 1);

    cout << "The kth smallest element is: " << kth_smallest << endl;

    return 0;
}